{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRIAL_NAME='ResNet_baseline'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Central experiment configuration read by the cells that follow.\n",
    "CONF = {\n",
    "    'niter': 200,                # iterations of the outer evaluate/train loop\n",
    "    'GPU': 0,                    # written to CUDA_VISIBLE_DEVICES\n",
    "    'BS': 128,                   # training DataLoader batch size\n",
    "    'test_BS': 256,              # validation DataLoader batch size\n",
    "    'N_neg': 3,                  # not referenced by the cells visible in this notebook\n",
    "    'name': TRIAL_NAME,\n",
    "    'tb_dir': os.path.join('./runs', TRIAL_NAME),  # TensorBoard log directory\n",
    "    'nz': 64,                    # not referenced by the cells visible in this notebook\n",
    "    'seed': 10708,               # seed for random / torch / numpy\n",
    "    'data_dir': '/DataSet/COCO', # root holding the image folders and annotations/\n",
    "    'dataType': 'train2017',     # split used for training\n",
    "    'valType': 'val2017',        # split used for validation\n",
    "    'LAMBDA': 0.5,               # not referenced by the cells visible in this notebook\n",
    "    'use_super': False,          # classify supercategory names instead of category ids\n",
    "    'test_classes': 80,          # number of classes sampled / size of the output layer\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expose only the configured GPU to CUDA; this cell runs before the cell that imports torch.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=str(CONF['GPU'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "# Remove the TensorBoard logs left by a previous run of this trial.\n",
    "# The pauses around the removal are kept as-is; presumably they give the filesystem or a\n",
    "# running TensorBoard time to settle -- TODO confirm they are still needed.\n",
    "time.sleep(2)\n",
    "shutil.rmtree(CONF['tb_dir'], ignore_errors=True)\n",
    "time.sleep(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import texar.torch as tx\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.parallel\n",
    "from torch.nn import functional as F\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torchvision\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torchvision.datasets as dset\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.utils as vutils\n",
    "import numpy as np\n",
    "from torch import autograd\n",
    "import multiprocessing\n",
    "from PIL import Image\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if True else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorBoard writer for this trial; it logs into CONF['tb_dir'].\n",
    "writer = SummaryWriter(log_dir=CONF['tb_dir'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the Python, PyTorch and NumPy random number generators with the configured seed.\n",
    "for _seed_fn in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    _seed_fn(CONF['seed'])\n",
    "\n",
    "# Turn on cuDNN benchmark mode.\n",
    "cudnn.benchmark = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alias for the training batch size; the DataLoader cells in this notebook read CONF['BS'] directly.\n",
    "batch_size = CONF['BS']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training pipeline: random crop resized to 256x256, mild colour jitter, random\n",
    "# horizontal flip, then conversion to a tensor.\n",
    "T = transforms.Compose([\n",
    "    transforms.RandomResizedCrop(size=(256, 256), scale=(0.3, 1.0), ratio=(3 / 4, 4 / 3)),\n",
    "    transforms.ColorJitter(brightness=0.1, contrast=0.05, saturation=0.05, hue=0.05),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Evaluation pipeline: resize to 256x256 and convert to a tensor, with no augmentation.\n",
    "T_test = transforms.Compose([\n",
    "    transforms.Resize(size=(256, 256)),\n",
    "    transforms.ToTensor(),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.datasets.vision import VisionDataset\n",
    "class CocoClassification(VisionDataset):\n",
    "    \"\"\"MS COCO images with multi-hot classification targets.\n",
    "\n",
    "    Each item is ``(image, targets)``. ``targets`` is a float tensor with one entry per\n",
    "    element of ``self.classes``: 1 when the image has at least one annotation of that\n",
    "    class and 0 otherwise. When ``CONF['use_super']`` is true the classes are the COCO\n",
    "    supercategory names, otherwise they are COCO category ids.\n",
    "\n",
    "    Args:\n",
    "        root (string): Root directory where images are downloaded to.\n",
    "        annFile (string): Path to json annotation file.\n",
    "        transform (callable, optional): A function/transform that  takes in an PIL image\n",
    "            and returns a transformed version. E.g, ``transforms.ToTensor``\n",
    "        target_transform (callable, optional): A function/transform that takes in the\n",
    "            target and transforms it.\n",
    "        transforms (callable, optional): A function/transform that takes input sample and its target as entry\n",
    "            and returns a transformed version.\n",
    "    \"\"\"\n",
    "\n",
    "    # Supercategory names used as the classes when CONF['use_super'] is enabled.\n",
    "    SUPER_CLASSES = (\"vehicle\", \"outdoor\", \"indoor\", \"person\", \"appliance\", \"furniture\",\n",
    "                     \"sports\", \"food\", \"kitchen\", \"accessory\", \"electronic\", \"animal\")\n",
    "\n",
    "    def sample_class(self, k):\n",
    "        \"\"\"Choose the active classes and rebuild the list of image ids.\n",
    "\n",
    "        Sets ``self.classes``, ``self.class_description`` and ``self.ids``.\n",
    "\n",
    "        Args:\n",
    "            k (int): Number of category ids to draw without replacement (via\n",
    "                ``np.random.choice``). Ignored when ``CONF['use_super']`` is true,\n",
    "                where the fixed supercategory list is used instead.\n",
    "        \"\"\"\n",
    "        if CONF['use_super']:\n",
    "            self.classes = np.array(self.SUPER_CLASSES)\n",
    "            self.class_description = list(self.SUPER_CLASSES)\n",
    "            # Previously this branch returned without ever setting self.ids, so\n",
    "            # __len__ and __getitem__ raised AttributeError. Index every image that\n",
    "            # belongs to at least one category, mirroring the branch below.\n",
    "            cat_ids = self.coco.getCatIds()\n",
    "        else:\n",
    "            class_list = self.coco.getCatIds()\n",
    "            self.classes = np.sort(np.random.choice(class_list, size=k, replace=False))\n",
    "            self.class_description = self.coco.loadCats(self.classes)\n",
    "            cat_ids = self.classes\n",
    "        # Union of the images of every selected category, in ascending id order.\n",
    "        img_ids = set()\n",
    "        for cat_id in cat_ids:\n",
    "            img_ids.update(self.coco.getImgIds(catIds=[cat_id]))\n",
    "        self.ids = sorted(img_ids)\n",
    "\n",
    "    def __init__(self, root, annFile, transform=None, target_transform=None, transforms=None):\n",
    "        super(CocoClassification, self).__init__(root, transforms, transform, target_transform)\n",
    "        from pycocotools.coco import COCO\n",
    "        self.coco = COCO(annFile)\n",
    "        # Start with every category of the annotation file selected.\n",
    "        self.sample_class(len(self.coco.getCatIds()))\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            index (int): Index\n",
    "\n",
    "        Returns:\n",
    "            tuple: Tuple (image, targets). targets is a float tensor with one 0/1 entry\n",
    "            per element of ``self.classes``; both values are passed through\n",
    "            ``self.transforms`` when it is set.\n",
    "        \"\"\"\n",
    "        coco = self.coco\n",
    "        img_id = self.ids[index]\n",
    "        ann_ids = coco.getAnnIds(imgIds=img_id)\n",
    "        cat_ids = [ann['category_id'] for ann in coco.loadAnns(ann_ids)]\n",
    "        # Key under which a category is matched against self.classes.\n",
    "        label_key = 'supercategory' if CONF['use_super'] else 'id'\n",
    "        # A set gives constant-time membership and also behaves for images without\n",
    "        # any annotation, where the previous numpy-array test compared against an\n",
    "        # empty float array.\n",
    "        present = {cat[label_key] for cat in coco.loadCats(cat_ids)}\n",
    "        targets = torch.FloatTensor([1 if c in present else 0 for c in self.classes.tolist()])\n",
    "        path = coco.loadImgs(img_id)[0]['file_name']\n",
    "        img = Image.open(os.path.join(self.root, path)).convert('RGB')\n",
    "        if self.transforms is not None:\n",
    "            img, targets = self.transforms(img, targets)\n",
    "\n",
    "        return img, targets\n",
    "\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of images selected by the last ``sample_class`` call.\"\"\"\n",
    "        return len(self.ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def _load_coco_split(split, transform):\n",
    "    \"\"\"Build a CocoClassification dataset for one split directory under CONF['data_dir'].\"\"\"\n",
    "    return CocoClassification(\n",
    "        root='{}/{}'.format(CONF['data_dir'], split),\n",
    "        annFile='{}/annotations/instances_{}.json'.format(CONF['data_dir'], split),\n",
    "        transform=transform,\n",
    "    )\n",
    "\n",
    "\n",
    "# Training split with augmentation, validation split with the plain resize pipeline.\n",
    "clas_set = _load_coco_split(CONF['dataType'], T)\n",
    "val_set = _load_coco_split(CONF['valType'], T_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-draw the training set's classes, then build the shuffled training loader.\n",
    "# NOTE(review): val_set is not re-sampled here and keeps the classes chosen in its\n",
    "# constructor, so both sets presumably only share a label layout while test_classes\n",
    "# equals the full category count -- confirm before lowering it.\n",
    "clas_set.sample_class(CONF['test_classes'])\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    clas_set,\n",
    "    batch_size=CONF['BS'],\n",
    "    shuffle=True,\n",
    "    num_workers=8,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation loader: fixed order, larger batches, no pinned memory.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_set,\n",
    "    batch_size=CONF['test_BS'],\n",
    "    shuffle=False,\n",
    "    num_workers=8,\n",
    "    pin_memory=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Flatten(nn.Module):\n",
    "    def forward(self, x):\n",
    "        x = x.view(x.size()[0], -1)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take a pretrained ResNet-18, drop its final child module, then flatten the features\n",
    "# and attach a new linear layer with one output per class.\n",
    "resnet18 = torchvision.models.resnet18(pretrained=True)\n",
    "modules = list(resnet18.children())[:-1] + [Flatten(), nn.Linear(512, CONF['test_classes'])]\n",
    "resnet = nn.Sequential(*modules)\n",
    "resnet.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGD with momentum over all model parameters.\n",
    "opt = optim.SGD(\n",
    "    resnet.parameters(),\n",
    "    lr=0.02,\n",
    "    momentum=0.9,\n",
    ")\n",
    "# Plateau scheduler; the training cell steps it with the mean validation loss.\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(\n",
    "    opt,\n",
    "    factor=0.2,\n",
    "    patience=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global step for the per-batch training-loss scalar; the training cell increments it once per batch.\n",
    "c=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss applied to the model outputs and the 0/1 target vectors in the training cell.\n",
    "criterion = nn.BCEWithLogitsLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm, trange"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):\n",
    "    '''\n",
    "    Compute the Hamming score (a.k.a. label-based accuracy) for the multi-label case\n",
    "    https://stackoverflow.com/q/32239577/395857\n",
    "    '''\n",
    "    acc_list = []\n",
    "    for i in range(y_true.shape[0]):\n",
    "        set_true = set( np.where(y_true[i])[0] )\n",
    "        set_pred = set( np.where(y_pred[i])[0] )\n",
    "        tmp_a = None\n",
    "        if len(set_true) == 0 and len(set_pred) == 0:\n",
    "            tmp_a = 1\n",
    "        else:\n",
    "            tmp_a = len(set_true.intersection(set_pred))/\\\n",
    "                    float( len(set_true.union(set_pred)) )\n",
    "        acc_list.append(tmp_a)\n",
    "    return np.mean(acc_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outer loop: each iteration first evaluates the current weights on the validation\n",
    "# set, logs the metrics, steps the LR scheduler, and then trains for one pass over\n",
    "# the training loader.\n",
    "for it in trange(CONF['niter']):\n",
    "\n",
    "    # --- Validation pass ---\n",
    "    l=[]            # per-batch validation losses\n",
    "    LBL=[]          # ground-truth 0/1 label arrays, one per batch\n",
    "    activation = [] # raw model outputs before the sigmoid, one array per batch\n",
    "    Y = []          # thresholded predictions, one array per batch\n",
    "    resnet.eval()\n",
    "    with torch.no_grad():\n",
    "        count = 0\n",
    "        corrects = torch.zeros(CONF['test_classes'])\n",
    "        for img,lbl in tqdm(val_loader, leave=False, desc=\"testing\"):\n",
    "            LBL.append(lbl.numpy())\n",
    "            lbl=lbl.cuda()\n",
    "            count += img.shape[0]\n",
    "            pred = resnet(img.cuda())\n",
    "            l.append(criterion(pred, lbl).item())\n",
    "            activation.append(pred.cpu().numpy())\n",
    "            pred = torch.sigmoid(pred)>.5\n",
    "            Y.append(pred.cpu().numpy())\n",
    "            # Per-class count of samples whose thresholded prediction matches the label.\n",
    "            corrects += torch.sum(torch.eq(pred, lbl), dim=0).cpu()\n",
    "        acc = (corrects/float(count))\n",
    "        writer.add_scalar(\"sup_acc/val_avg\", torch.mean(acc).item(), global_step=it)\n",
    "        writer.add_histogram('baseline/acc_val', acc.cpu().numpy(), global_step=it)\n",
    "        writer.add_scalar(\"sup_loss/val_loss\", np.average(l), global_step=it)\n",
    "        writer.flush()\n",
    "\n",
    "    Y, LBL, activation = np.concatenate(Y,axis=0), np.concatenate(LBL,axis=0), np.concatenate(activation,axis=0)\n",
    "    writer.add_scalar(\"sup_metrics/subset_acc\", metrics.accuracy_score(LBL, Y, normalize=True), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/hamming_loss\", metrics.hamming_loss(LBL, Y), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/hamming_score\", hamming_score(LBL, Y), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/micro_f1\", metrics.f1_score(LBL, Y, average='micro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/macro_f1\", metrics.f1_score(LBL, Y, average='macro'), global_step=it)\n",
    "    # ROC AUC is a ranking metric, so it is computed from the raw scores rather than\n",
    "    # from the thresholded predictions (which were passed here before and collapse\n",
    "    # each ROC curve to a single operating point).\n",
    "    writer.add_scalar(\"sup_metrics/micro_roc_auc\", metrics.roc_auc_score(LBL, activation, average='micro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/macro_roc_auc\", metrics.roc_auc_score(LBL, activation, average='macro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/micro_precision\", metrics.precision_score(LBL,Y,average='micro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/macro_precision\", metrics.precision_score(LBL,Y,average='macro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/micro_recall\", metrics.recall_score(LBL,Y,average='micro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/macro_recall\", metrics.recall_score(LBL,Y,average='macro'), global_step=it)\n",
    "    writer.add_scalar(\"sup_metrics/avg_acc\", np.average(np.sum((Y==LBL), axis=0)/Y.shape[0]), global_step=it)\n",
    "    # The plateau scheduler monitors the mean validation loss.\n",
    "    scheduler.step(np.average(l))\n",
    "    \n",
    "    # --- Training pass ---\n",
    "    count = 0\n",
    "    corrects = torch.zeros(CONF['test_classes'])\n",
    "    resnet.train()\n",
    "    for img,lbl in tqdm(train_loader, leave=False, desc=\"training\"):\n",
    "        count += img.shape[0]\n",
    "        img = img.cuda()\n",
    "        pred = resnet(img)\n",
    "        loss = criterion(pred,lbl.cuda())\n",
    "        with torch.no_grad():\n",
    "            # Running per-class accuracy on the training batches (labels stay on the CPU).\n",
    "            corrects += torch.sum(torch.eq((torch.sigmoid(pred)>.5).cpu(), lbl), dim=0)\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        # `c` is the global batch counter, so the loss curve is continuous across epochs.\n",
    "        writer.add_scalar(\"sup_loss/loss\", loss.item(), global_step=c)\n",
    "        c+=1\n",
    "    acc = (corrects/float(count))\n",
    "    writer.add_scalar(\"sup_acc/train_avg\", torch.mean(acc).item(), global_step=it)\n",
    "    writer.add_histogram('baseline/acc_train', acc.cpu().numpy(), global_step=it)\n",
    "\n",
    "writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
